{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 机器学习工程师纳米学位\n",
    "## 深度学习\n",
    "## 项目：搭建一个数字识别项目\n",
    "\n",
    "在此文件中，我们提供给你了一个模板，以便于你根据项目的要求一步步实现要求的功能，进而完成整个项目。如果你认为需要导入另外的一些代码，请确保你正确导入了他们，并且包含在你的提交文件中。以**'练习'**开始的标题表示接下来你将开始实现你的项目。注意有一些练习是可选的，并且用**'可选'**标记出来了。\n",
    "\n",
    "在此文件中，有些示例代码已经提供给你，但你还需要实现更多的功能让项目成功运行。除非有明确要求，你无须修改任何已给出的代码。以'练习'开始的标题表示接下来的代码部分中有你必须要实现的功能。每一部分都会有详细的指导，需要实现的部分也会在注释中以'TODO'标出。请仔细阅读所有的提示！\n",
    "\n",
    "除了实现代码外，你还必须回答一些与项目和你的实现有关的问题。每一个需要你回答的问题都会以**'问题 X'**为标题。请仔细阅读每个问题，并且在问题后的**'回答'**文字框中写出完整的答案。我们将根据你对问题的回答和撰写代码所实现的功能来对你提交的项目进行评分。\n",
    "\n",
    ">**注意：** Code 和 Markdown 区域可通过 **Shift + Enter** 快捷键运行。此外，Markdown可以通过双击进入编辑模式。"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 连接 mnist 的字符来合成数据\n",
    "\n",
    "你可以通过连接[MNIST](http://yann.lecun.com/exdb/mnist/)的字符来合成数据来训练这个模型。为了快速导入数据集，我们可以使用 [Keras Datasets](https://keras.io/datasets/#mnist-database-of-handwritten-digits) [中文文档](http://keras-cn.readthedocs.io/en/latest/other/datasets/#mnist)。"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 载入 mnist"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from keras.datasets import mnist\n",
    "\n",
    "(X_raw, y_raw), (X_raw_test, y_raw_test) = mnist.load_data()\n",
    "\n",
    "n_train, n_test = X_raw.shape[0], X_raw_test.shape[0]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 可视化 mnist\n",
    "\n",
    "我们可以通过 matplotlib 来可视化我们的原始数据集。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "import random\n",
    "\n",
    "%matplotlib inline\n",
    "%config InlineBackend.figure_format = 'retina'\n",
    "\n",
    "for i in range(15):\n",
    "    plt.subplot(3, 5, i+1)\n",
    "    index = random.randint(0, n_train-1)\n",
    "    plt.title(str(y_raw[index]))\n",
    "    plt.imshow(X_raw[index], cmap='gray')\n",
    "    plt.axis('off')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 练习：合成数据\n",
    "\n",
    "你需要随机取随机张图片，然后将它们拼接成新的图片。\n",
    "\n",
    "你需要设置20%的数据作为验证集，以保证模型没有过拟合。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "from sklearn.model_selection import train_test_split\n",
    "\n",
    "n_class, n_len, width, height = 11, 5, 28, 28\n",
    "\n",
    "def generate_dataset(X, y):\n",
    "    X_len = X.shape[0]\n",
    "    \n",
    "    X_gen = np.zeros((X_len, height, width*n_len, 1), dtype=np.uint8)\n",
    "    y_gen = [np.zeros((X_len, n_class), dtype=np.uint8) for i in range(n_len)]\n",
    "    # TODO: 随机取1~5个数字，并拼接成新的图片\n",
    "    \n",
    "    \n",
    "    return X_gen, y_gen\n",
    "\n",
    "X_raw_train, X_raw_valid, y_raw_train, y_raw_valid = None\n",
    "\n",
    "X_train, y_train = generate_dataset(X_raw_train, y_raw_train)\n",
    "X_valid, y_valid = generate_dataset(X_raw_valid, y_raw_valid)\n",
    "X_test, y_test = generate_dataset(X_raw_test, y_raw_test)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 显示生成的图片\n",
    "for i in range(15):\n",
    "    plt.subplot(5, 3, i+1)\n",
    "    index = random.randint(0, n_test-1)\n",
    "    title = ''\n",
    "    for j in range(n_len):\n",
    "        title += str(np.argmax(y_test[j][index])) + ','\n",
    "    \n",
    "    plt.title(title)\n",
    "    plt.imshow(X_test[index][:,:,0], cmap='gray')\n",
    "    plt.axis('off')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 问题 1\n",
    "_你是如何合成数据集的？为什么要分训练集，验证集和测试集？_\n",
    "\n",
    "**回答：**\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 练习：设计并测试一个模型架构\n",
    "\n",
    "设计并实现一个能够识别数字序列的深度学习模型。为了产生用于测试的合成数字序列，你可以进行如下的设置：比如，你可以限制一个数据序列最多五个数字，并在你的深度网络上使用五个分类器。同时，你有必要准备一个额外的“空白”的字符，以处理相对较短的数字序列。\n",
    "\n",
    "在思考这个问题的时候有很多方面可以考虑：\n",
    "\n",
    "- 你的模型可以基于深度神经网络或者是卷积神经网络。\n",
    "- 你可以尝试是否在每个分类器间共享权值。\n",
    "- 你还可以在深度神经网络中使用循环网络来替换其中的分类层，并且将数字序列里的数字一个一个地输出。\n",
    "\n",
    "在使用 Keras 搭建模型的时候，你可以使用 [函数式模型 API](http://keras-cn.readthedocs.io/en/latest/models/model/) 的方式来搭建多输出模型。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "from keras.models import Model\n",
    "from keras.layers import *\n",
    "\n",
    "# TODO: 构建你的模型\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 问题 2\n",
    "_你为解决这个问题采取了什么技术？请详细介绍你使用的技术。_\n",
    "\n",
    "**回答：** \n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 可视化你的网络模型\n",
    "\n",
    "参考链接：[可视化visualization](http://keras-cn.readthedocs.io/en/latest/other/visualization/)\n",
    "\n",
    "可以是保存成 PNG 格式显示，也可以直接使用 SVG 格式。 SVG 是矢量图，它的优点是可以无限放大。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from keras.utils.vis_utils import plot_model, model_to_dot\n",
    "from IPython.display import Image, SVG\n",
    "\n",
    "# TODO: 可视化你的模型"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 问题 3\n",
    "_你最终的模型架构是什么样的？（什么类型的模型，层数，大小, 如何连接等）_\n",
    "\n",
    "**回答：**\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 练习：训练你的网络模型\n",
    "\n",
    "训练你的模型时，需要设置训练集和验证集。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# TODO: 训练你的模型\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 练习：计算你的模型准确率\n",
    "\n",
    "我们刚才得到了模型每个数字的准确率，现在让我们来计算整体准确率，按照完全预测正确数字序列的标准来计算。\n",
    "\n",
    "比如 1,2,3,10,10 预测成了 1,2,10,10,10 算错，而不是算对了80%。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "def evaluate(model):\n",
    "    # TODO: 按照错一个就算错的规则计算准确率\n",
    "    \n",
    "\n",
    "evaluate(model)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 问题 4\n",
    "\n",
    "_你的模型准确率有多少？你觉得你的模型足以解决问题吗？_\n",
    "\n",
    "**回答：**"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 预测值可视化\n",
    "\n",
    "我们将模型的预测结果和真实值画出来，观察真实效果。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_result(result):\n",
    "    # 将 one_hot 编码解码\n",
    "    resultstr = ''\n",
    "    for i in range(n_len):\n",
    "        resultstr += str(np.argmax(result[i])) + ','\n",
    "    return resultstr\n",
    "\n",
    "index = random.randint(0, n_test-1)\n",
    "y_pred = model.predict(X_test[index].reshape(1, height, width*n_len, 1))\n",
    "\n",
    "plt.title('real: %s\\npred:%s'%(get_result([y_test[x][index] for x in range(n_len)]), get_result(y_pred)))\n",
    "plt.imshow(X_test[index,:,:,0], cmap='gray')\n",
    "plt.axis('off')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 保存模型\n",
    "\n",
    "模型达到满意的效果以后，我们需要保存模型，以便下次调用。\n",
    "\n",
    "读取的方式也很简单：`model = load_model('model.h5')`"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "model.save('model.h5')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "anaconda-cloud": {},
  "kernelspec": {
   "display_name": "Python 2",
   "language": "python",
   "name": "python2"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 2
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython2",
   "version": "2.7.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}
